{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from abp_artf_plugin.pipeline import ProcessingPipeline\n",
    "import numpy as np\n",
    "import wfdb\n",
    "import random\n",
    "from utils.stats import calculate_cls_metrics\n",
    "from scipy.signal import butter, filtfilt\n",
    "from scipy.signal import resample\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import rcParams, font_manager as fm\n",
    "arial_font_path = 'fonts/Arial.ttf'\n",
    "fm.fontManager.addfont(arial_font_path)\n",
    "arial_font = fm.FontProperties(fname=arial_font_path)\n",
    "rcParams['font.family'] = arial_font.get_name()\n",
    "\n",
    "\n",
    "import plotly.graph_objects as go\n",
    "from plotly.subplots import make_subplots\n",
    "import plotly.io as pio\n",
    "\n",
    "#Set seed for reproducibility\n",
    "seed = 224\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_data_and_label(data, label, save_memory=True):\n",
    "    # Create a subplot with two rows\n",
    "    fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.1,\n",
    "                        subplot_titles=(\"Processed Data\", \"Label\"))\n",
    "\n",
    "    # Plot the processed data in the first subplot\n",
    "    fig.add_trace(go.Scatter(\n",
    "        x=np.arange(len(data)),\n",
    "        y=data,\n",
    "        mode='lines',\n",
    "        name='Processed Data',\n",
    "    ), row=1, col=1)\n",
    "\n",
    "    # Plot the label in the second subplot\n",
    "    fig.add_trace(go.Scatter(\n",
    "        x=np.arange(len(label)),\n",
    "        y=label,\n",
    "        mode='lines',\n",
    "        name='Label',\n",
    "    ), row=2, col=1)\n",
    "\n",
    "    # Update layout for better visualization\n",
    "    fig.update_layout(\n",
    "        title=\"Processed Data and Label\",\n",
    "        xaxis_title=\"Index\",\n",
    "        yaxis_title=\"Data\",\n",
    "        xaxis2_title=\"Index\",\n",
    "        yaxis2_title=\"Label\",\n",
    "        height=600\n",
    "    )\n",
    "\n",
    "    # Show the plot\n",
    "    if save_memory:\n",
    "        pio.show()\n",
    "    else:\n",
    "        fig.show()\n",
    "        \n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p = ProcessingPipeline()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test on MIMIC-III dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# file_path:\n",
    "record = wfdb.rdrecord('../datasets/physionet.org/files/mimic3wdb-matched/1.0/p00/p000020/p000020-2183-04-28-17-47') \n",
    "display(record.__dict__)\n",
    "signals = record.p_signal\n",
    "display(signals.shape)\n",
    "\n",
    "wfdb.plot_wfdb(record=record, title='p000020')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fs = 125\n",
    "skip_s = 60*5\n",
    "duration_s = 60*60*3\n",
    "data_abp = signals[fs*skip_s:fs*(skip_s+duration_s), 2]\n",
    "print(data_abp.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# high risk memory line!\n",
    "data, label, mse_list = p.process(data_abp, fs)\n",
    "print(data.shape)\n",
    "print(label.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plot_data_and_label(data_abp, label.astype(int), save_memory=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Arterial Blood Pressure Hypertension"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_abp = np.load('buffer/data_abp.npy')\n",
    "data_label = np.load('buffer/data_abp_label.npy')\n",
    "\n",
    "# masked data with label 0\n",
    "data_abp_masked = data_abp.copy()\n",
    "data_abp_masked[data_label == 1] = 0\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# calculate the hypertension period\n",
    "from scipy.signal import find_peaks\n",
    "\n",
    "def find_sbp_dbp(abp_series):\n",
    "    \"\"\"\n",
    "    Find Systolic Blood Pressure (SBP) and Diastolic Blood Pressure (DBP) from a high-resolution ABP series.\n",
    "    \n",
    "    Parameters:\n",
    "    - abp_series: 1D numpy array, high-resolution ABP data.\n",
    "    \n",
    "    Returns:\n",
    "    - sbp_values: List of detected Systolic BP values (peaks).\n",
    "    - dbp_values: List of detected Diastolic BP values (troughs).\n",
    "    \"\"\"\n",
    "    # Find systolic peaks (SBP) - local maxima\n",
    "    sbp_peaks, _ = find_peaks(abp_series, distance=50)  # Distance avoids small variations\n",
    "    sbp_values = abp_series[sbp_peaks]\n",
    "    \n",
    "    # Find diastolic troughs (DBP) - local minima\n",
    "    dbp_troughs, _ = find_peaks(-abp_series, distance=50)\n",
    "    dbp_values = abp_series[dbp_troughs]\n",
    "    return sbp_values, dbp_values\n",
    "\n",
    "def count_hypertension_events(sbp, dbp):\n",
    "    n_sbp_hypertension = np.sum(sbp > 140)\n",
    "    n_dbp_hypertension = np.sum(dbp > 90)\n",
    "    return n_sbp_hypertension + n_dbp_hypertension\n",
    "\n",
    "def cnt_hypertension(abp):\n",
    "    sbp, dbp = find_sbp_dbp(abp)\n",
    "    return count_hypertension_events(sbp, dbp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = go.Figure()\n",
    "fig.add_trace(go.Scatter(x=np.arange(data_abp.shape[0]), y=data_abp, mode='lines', name='ABP'))\n",
    "pio.show()\n",
    "del fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fs = 125\n",
    "skip_s = 60*5\n",
    "duration_s = 60*60*3\n",
    "data_pap = signals[fs*skip_s:fs*(skip_s+duration_s), 3]\n",
    "print(data_pap.shape)\n",
    "pdata_pap, label_pap, mse_list_pap = p.process(data_pap, fs)\n",
    "print(data_pap.shape)\n",
    "print(label_pap.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_data_and_label(data_pap, label_pap)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test on Our Dataset (Please write own dataloader function!)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load our data \n",
    "# from data_provider import data_factory\n",
    "# data, label = data_factory.npy_provider('240405', flag='test') # use help function\n",
    "display(data.shape)\n",
    "data, item_labels, mse_list = p.process(data, 120)\n",
    "items_per_sample = np.mean(item_labels, axis=1)\n",
    "print(calculate_cls_metrics(label, items_per_sample))\n",
    "# {'accuracy': 0.95, 'f1_score': 0.9480249480249481, 'sensitivity': 0.912, 'specificity': 0.988}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sampling Frequency Experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import sys\n",
    "# sys.path.append(\"../BP-artefact-removal\")\n",
    "# from data_provider import data_factory\n",
    "\n",
    "# data, label = data_factory.npy_provider('240405', flag='test')\n",
    "# print(data.shape, label.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "def resample_to(data, ori_fs, target_fs):\n",
    "    assert data.ndim == 2, \"Data must be 2D, such as N*1200\"\n",
    "    \n",
    "    num_samples = int((target_fs / ori_fs) * data.shape[1])\n",
    "    resampled = np.zeros((data.shape[0], num_samples))\n",
    "    for i in range(data.shape[0]):\n",
    "        resampled[i, :] = resample(data[i, :], num_samples)\n",
    "    return resampled\n",
    "\n",
    "# Target sampling rates\n",
    "target_fs_list = [50, 75, 100, 120, 125, 150, 175, 200, 240]\n",
    "\n",
    "for target_fs in target_fs_list:\n",
    "    if target_fs != 120:\n",
    "        resampled_data = resample_to(data, 120, target_fs)\n",
    "    else:\n",
    "        resampled_data = data\n",
    "    print(f\"Target FS: {target_fs} Hz, Resampled Shape: {resampled_data.shape}\")\n",
    "    tit = time.time()\n",
    "    data_recon, item_labels, mse_list = p.process(resampled_data, target_fs)\n",
    "    tot = time.time()\n",
    "    items_per_sample = np.around(np.mean(item_labels, axis=1))\n",
    "    print(tot-tit)\n",
    "    print(calculate_cls_metrics(label, items_per_sample))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "bpenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
